####################################################
#Author: Kelli Marquardt
#Purpose: Estimate Model (stage1 & stage2) and produce related tables/figures

# Inputs:
#- data/est_dat_fake.csv 

# Outputs:
#- data/intermediate/pcp_fe_dat.csv 
#- data/intermediate/main_est.csv 
#- output/figures/fig_4.png, fig_a1.png 
#- output/tables/tab_5.txt, tab_7.txt, tab_8.txt, tab_a2.txt, tab_a4.txt, tab_a6.txt, tab_a7.txt 

####################################################

############################
#0 load required packages
############################
rm(list = ls(all.names = TRUE))

#load packages 
library(MASS)
library(dplyr) 
library(fastDummies)
library(ggplot2)
library(sandwich)
library(lmtest)
library(cowplot)



############################
#0 Define functions used in estimation and simulation 
############################

#demean numeric variables
#  x: a vector (typically one column of a data frame)
#  returns: x unchanged when it is a character vector; otherwise x minus its
#           mean, where the mean is taken over the non-missing values
#           (missing entries stay missing in the result)
demean=function(x){
  #is.character() always gives a single TRUE/FALSE; comparing class(x) to a
  #string gives one value per class entry, which if() rejects for objects
  #carrying more than one class
  if(is.character(x)){
    return(x)
  }
  x-mean(x, na.rm=TRUE)
} 

#Back out the remaining model parameters for one group, given mu.
# Arguments:
#  mu          - mean parameter supplied by the caller
#  qhat        - share of observations with Q=1
#  xbar        - mean of x among observations with Q=1
#  bhat, ahat  - slope and constant of the probit of D on x
#  sigma_lower, sigma_upper - interval searched for sigma
# Returns list(rho, sig, tau, cbar); stops when the interval does not bracket
# a sign change of the root function.
recover_params=function(mu, qhat, xbar, bhat, ahat,
                        sigma_lower=1e-8, sigma_upper=1e3){

  z_q=qnorm(qhat)
  k_scale=(xbar-mu)*qhat/dnorm(z_q)

  #rho implied by the probit slope at a candidate sigma
  rho_slope=function(s){
    (bhat*s)/sqrt(1+((bhat*s)^2))
  }
  #rho implied by the conditional mean of x at a candidate sigma
  rho_mean=function(s){
    k_scale*sqrt(1+(s^2))/(s^2)
  }
  #sigma is the point where the two expressions for rho agree
  rho_gap=function(s){
    rho_mean(s)-rho_slope(s)
  }

  #uniroot needs finite end values of opposite sign (or a zero at an end)
  gap_lo=rho_gap(sigma_lower)
  gap_hi=rho_gap(sigma_upper)
  bracketed=is.finite(gap_lo) && is.finite(gap_hi) && !(gap_lo*gap_hi > 0)
  if (!bracketed) {
    stop("Could not bracket a root for sigma. Try different sigma_lower/sigma_upper.")
  }

  sig_hat=uniroot(rho_gap, lower = sigma_lower, upper = sigma_upper)$root

  #remaining parameters follow in closed form from sigma
  rho_hat=rho_slope(sig_hat)
  list(rho = rho_hat,
       sig = sig_hat,
       tau = (1 - rho_hat) * mu - ahat * sig_hat * sqrt(1 - (rho_hat^2) ),
       cbar = mu - z_q * sqrt(1 + (sig_hat^2)))
}


#Simulate N patients from one parameter set and summarise the outcomes.
# param_vec: list-like object read with $ for mu, sig, c_bar, rho and tau
# N:         number of simulated patients
# seed:      handed to set.seed() before any draw (default: today's date as integer)
# Returns a list of rates (dx_rate, q_rate, xbar_c, dx_c) and the sums behind
# them (dx_sum, q_sum, xsum_c, dxsum_c); the *_c entries use Qi==1 rows only.
simulateDx=function(param_vec, N=10000, seed=as.integer(Sys.Date())){
  mu_par=param_vec$mu
  sd_par=param_vec$sig
  cost_par=param_vec$c_bar
  rho_par=param_vec$rho
  tau_par=param_vec$tau

  #draw (vi, xi, ci): vi and xi share mean mu and variance sig^2 with
  #covariance rho*sig^2; ci has mean c_bar, unit variance, zero covariances
  set.seed(seed)
  draw_means=c(mu_par, mu_par, cost_par)
  draw_cov=matrix(c(
    sd_par^2, rho_par*sd_par^2, 0,
    rho_par*sd_par^2, sd_par^2, 0,
    0,0,1),
    nrow = 3, byrow = TRUE)
  draws=as.data.frame(mvrnorm(N, draw_means, draw_cov))
  colnames(draws)=c("vi","xi", "ci")

  #uniform draw that is compared against the diagnosis probability
  draws$ui=runif(N, 0, 1)

  #Qi is 1 when vi exceeds ci
  draws$Qi=ifelse(draws$vi>draws$ci, 1, 0)
  #diagnosis probability as a function of xi
  draws$probDi_x=pnorm((rho_par*draws$xi+(1-rho_par)*mu_par-tau_par)/
                         (sd_par*sqrt(1-(rho_par^2))))
  #Di is 1 only when Qi==1 and the uniform draw falls below probDi_x
  draws$Di=ifelse(draws$probDi_x>draws$ui & draws$Qi==1, 1, 0)

  #row positions with Qi==1, used for the conditional moments
  q_rows=which(draws$Qi==1)

  list(dx_rate=mean(draws$Di),
       q_rate=mean(draws$Qi),
       xbar_c=mean(draws$xi[q_rows]),
       dx_c=mean(draws$Di[q_rows]),

       dx_sum=sum(draws$Di),
       q_sum=sum(draws$Qi),
       xsum_c=sum(draws$xi[q_rows]),
       dxsum_c=sum(draws$Di[q_rows]))
}

#simulation controls reused by every simulateDx() call in this script
n_opt <- 10000
seed_opt <- 12345

#################################################
#Step 1: read in data and subset to step1_dat and step2_dat
#################################################

#full estimation sample
est_dat <- read.csv(
  file.path("..", "data", "est_dat_fake.csv"),
  stringsAsFactors = FALSE
)

#first-stage inputs: ids, demographics, Qi, xi and every *_uptoQ1 column
step1_dat <- select(
  est_dat,
  pat_id, male, hispanic, white, Med, Com, birth_year,
  first_pcp_id, Qi, xi,
  ends_with("_uptoQ1")
)

#second-stage inputs
step2_dat <- select(est_dat, pat_id, male, Qi, xi, Di)

#################################################
#Step 2: Run first stage to obtain estimates mu_m and mu_f
#################################################


###################
#2a- prep step1 data 
###################

#One pipeline that:
# - strips the _uptoQ1 suffix from the column names that carry it
# - renames age_mean to age and adds age^2
# - adds numapt, the sum of the 2014-2017 year columns
# - adds one dummy column per birth_year value
# - keeps only the variables used in estimation
step1_dat <- step1_dat %>%
  rename_with(function(nm) sub("_uptoQ1", "", nm), ends_with("_uptoQ1")) %>%
  rename(age = age_mean) %>%
  mutate(
    age2 = age^2,
    numapt = year_2014 + year_2015 + year_2016 + year_2017
  ) %>%
  dummy_cols(select_columns = "birth_year") %>%
  select(
    pat_id, first_pcp_id, Qi, male, xi,
    white, hispanic, numapt, numdocs, age, age2,
    starts_with("year"), starts_with("birth_year_"),
    Med, Com, psych_doc, well, behav
  )


##
#take step 1 dat and keep only patients whose first pcp is usable for
#pcp-by-gender estimation:
#remove patients whose first pcp has 2 or fewer male or 2 or fewer female patients
#  (the filter requires num_male>2 & num_female>2)
#  NOTE(review): this cutoff was previously documented as "n_gender <2" -- confirm the intended threshold
#remove patients whose first pcp has a mean Qi of exactly 0 or 1 for either gender
step1_dat=step1_dat%>%
  group_by(first_pcp_id)%>%
  #patient counts by gender within each first pcp
  mutate(num_male=sum(male),
         num_female=n()-num_male)%>%
  filter(num_male>2 & num_female>2)%>%
  #share of patients with Qi=1 by gender within each first pcp
  mutate(meanQ_male=sum(Qi*male)/num_male,
         meanQ_female=sum(Qi*(1-male))/num_female)%>%
  filter(!(meanQ_male %in% c(0,1) | meanQ_female %in% c(0,1)))%>%
  ungroup()%>%
  #drop the helper columns created above
  select(-c(num_male, num_female, meanQ_male, meanQ_female))


#First-stage covariates: drop the id/outcome columns, keep complete rows only,
#and demean every remaining column except pat_id
dat_fs <- step1_dat %>%
  select(-c(first_pcp_id, Qi, male, xi)) %>%
  na.omit() %>%
  mutate(across(-pat_id, demean))


#Re-attach Qi, pcp id and male for the retained patients and build the
#pcp-by-gender key "<first_pcp_id>_<male>"
dat_fs <- step1_dat %>%
  select(pat_id, Qi, first_pcp_id, male) %>%
  right_join(dat_fs, by = "pat_id") %>%
  mutate(pcpg = paste0(first_pcp_id, "_", male))


###################
#2b- estimate P(Qi) and pull out pcp-by-gender fixed effects
###################

#First-stage formula: Qi on every covariate in dat_fs (all columns except the
#ids, Qi, male and pcpg) plus pcp-by-gender indicators, with no intercept
reg_fs <- as.formula(
  paste0(
    "Qi ~ ",
    paste(
      names(dat_fs)[!names(dat_fs) %in% c("pat_id", "Qi", "first_pcp_id", "male", "pcpg")],
      collapse = " + "
    ),
    " + as.factor(pcpg) - 1"
  )
)

## estimate first stage (probit)
step1_reg <- glm(reg_fs, family = binomial(link = "probit"), data = dat_fs)

#Coefficient table -> one row per pcp-by-gender indicator:
# pcpg   = coefficient name with the "as.factor(pcpg)" prefix (15 chars) removed
# fe_est = pnorm() of the coefficient estimate
# fe_se  = reported standard error of the coefficient
pcp_fe <- as.data.frame(coef(summary(step1_reg)))
pcp_fe$varname <- rownames(pcp_fe)
pcp_fe <- pcp_fe %>%
  filter(startsWith(varname, "as.factor(pcpg)")) %>%
  mutate(
    pcpg = substring(varname, 16),
    fe_est = pnorm(Estimate),
    fe_se = `Std. Error`
  ) %>%
  select(pcpg, fe_est, fe_se)

#attach the actual pcp id and the male indicator to each pcpg
pcp_fe <- dat_fs %>%
  select(first_pcp_id, male, pcpg) %>%
  distinct() %>%
  right_join(pcp_fe, by = "pcpg")


###################
#2c- merge in xi and conduct extrapolations of xi on pcp-by-gender fe
###################

#######
#mean of xi (and its se) for each pcpg, using patients with Qi==1 whose
#pcp-by-gender cell has a first-stage estimate
xi_est_dat <- step1_dat %>%
  filter(Qi == 1) %>%
  select(pat_id, male, xi, first_pcp_id) %>%
  left_join(pcp_fe, by = c("first_pcp_id", "male")) %>%
  filter(!is.na(fe_est))

#regress xi on the pcpg indicators only (no intercept) and keep the
#coefficient table
xi_est_dat <- as.data.frame(
  coef(summary(lm(xi ~ as.factor(pcpg) - 1, data = xi_est_dat)))
)
xi_est_dat$varname <- rownames(xi_est_dat)
xi_est_dat <- xi_est_dat %>%
  filter(startsWith(varname, "as.factor(pcpg)")) %>%
  mutate(
    pcpg = substring(varname, 16),
    xi_est = Estimate,
    xi_se = `Std. Error`
  ) %>%
  select(pcpg, xi_est, xi_se)

#merge into pcp_fe and define the regression weights (1 / se of xi_est)
extrap_dat <- pcp_fe %>%
  left_join(xi_est_dat, by = "pcpg") %>%
  mutate(wt = 1 / (xi_se))

#####
#aside: appendix dataset = extrap_dat plus, per pcp and gender, the number of
#patients (npat) and the number with Qi=1 (npatQ)
pcp_fe_dat <- dat_fs %>%
  group_by(first_pcp_id, male) %>%
  summarise(
    npat = n(),
    npatQ = sum(Qi),
    .groups = "drop"
  ) %>%
  left_join(extrap_dat, by = c("first_pcp_id", "male")) %>%
  select(-c(pcpg, wt))

#save to csv file (data/intermediate)
write.csv(
  pcp_fe_dat,
  file = file.path("..", "data", "intermediate", "pcp_fe_dat.csv"),
  row.names = FALSE
)

rm(pcp_fe_dat)
#####


#Weighted fits of xi_est on fe_est, separately by gender.
#Each object holds the coefficient table from summary().

#linear: xi_est = a0 + a1 * fe_est
lin_coef_m <- summary(
  lm(xi_est ~ fe_est, data = extrap_dat, subset = (male == 1), weights = wt)
)$coefficients
lin_coef_f <- summary(
  lm(xi_est ~ fe_est, data = extrap_dat, subset = (male == 0), weights = wt)
)$coefficients

#exponential: xi_est = b0 * exp(b1 * fe_est)
exp_coef_m <- summary(
  nls(xi_est ~ b0 * exp(b1 * fe_est),
      data = extrap_dat, subset = (male == 1),
      weights = wt, start = list(b0 = 0.5, b1 = -0.5))
)$coefficients
exp_coef_f <- summary(
  nls(xi_est ~ b0 * exp(b1 * fe_est),
      data = extrap_dat, subset = (male == 0),
      weights = wt, start = list(b0 = 0.5, b1 = -0.5))
)$coefficients

#######
#print extrapolation tables 
#######

#linear
#Writes the linear-extrapolation table (tab_a4.txt): intercept and slope with
#standard errors for males and females, plus the fitted mu.
coef_m=lin_coef_m
coef_f=lin_coef_f
#fitted mu = intercept + slope, i.e. the linear fit evaluated at fe_est = 1
#note: mu_m and mu_f are assigned again from the exponential fit further down,
#so these linear values are only used for this table
mu_m=coef_m[1, "Estimate"]+coef_m[2, "Estimate"]
mu_f=coef_f[1, "Estimate"]+coef_f[2, "Estimate"]
#assemble the LaTeX tabular as one string (sep="" so spacing comes from the literals)
extrap_linear=paste("\\begin{tabular}{lcc}\n",
          "\\toprule\n",
          "& Male & Female \\\\\n",
          "& (1) & (2) \\\\\n",
          "\\midrule\n",
          "$\\widehat{\\alpha_0}$ &", 
          sprintf("%.3f", coef_m[1, "Estimate"]), "&", 
          sprintf("%.3f", coef_f[1, "Estimate"]), "\\\\\n",
          "& (", sprintf("%.3f", coef_m[1, "Std. Error"]), ") & (", 
          sprintf("%.3f", coef_f[1, "Std. Error"]), ")\\\\\n",
          "\\addlinespace\n",
          "$\\widehat{\\alpha_1}$ &", 
          sprintf("%.3f", coef_m[2, "Estimate"]), "&", 
          sprintf("%.3f", coef_f[2, "Estimate"]), "\\\\\n",
          "& (", sprintf("%.3f", coef_m[2, "Std. Error"]), ") & (", 
          sprintf("%.3f", coef_f[2, "Std. Error"]), ")\\\\\n",
          "\\midrule\n",
          "Fitted $\\mu_\\theta$ &", 
          sprintf("%.3f", mu_m), "&", 
          sprintf("%.3f", mu_f), "\\\\\n",
          "\\bottomrule\n",
          "\\end{tabular}\\\\\n", sep="")
write(extrap_linear, file = file.path("..", "output", "tables", "tab_a4.txt"))

#exponential
#Writes the exponential-extrapolation table (tab_5.txt): b0 and b1 (printed as
#alpha_0 and alpha_1) with standard errors for males and females, plus the fitted mu.
coef_m=exp_coef_m
coef_f=exp_coef_f
#fitted mu = b0*exp(b1), i.e. the exponential fit evaluated at fe_est = 1
#these assignments replace the linear mu_m/mu_f; Step 3 reads these values via mu_fm
mu_m=coef_m[1, "Estimate"]*exp(coef_m[2, "Estimate"])
mu_f=coef_f[1, "Estimate"]*exp(coef_f[2, "Estimate"])
#assemble the LaTeX tabular as one string (sep="" so spacing comes from the literals)
extrap_exp=paste("\\begin{tabular}{lcc}\n",
                    "\\toprule\n",
                    "& Male & Female \\\\\n",
                    "& (1) & (2) \\\\\n",
                    "\\midrule\n",
                    "$\\widehat{\\alpha_0}$ &", 
                    sprintf("%.3f", coef_m[1, "Estimate"]), "&", 
                    sprintf("%.3f", coef_f[1, "Estimate"]), "\\\\\n",
                    "& (", sprintf("%.3f", coef_m[1, "Std. Error"]), ") & (", 
                    sprintf("%.3f", coef_f[1, "Std. Error"]), ")\\\\\n",
                    "\\addlinespace\n",
                    "$\\widehat{\\alpha_1}$ &", 
                    sprintf("%.3f", coef_m[2, "Estimate"]), "&", 
                    sprintf("%.3f", coef_f[2, "Estimate"]), "\\\\\n",
                    "& (", sprintf("%.3f", coef_m[2, "Std. Error"]), ") & (", 
                    sprintf("%.3f", coef_f[2, "Std. Error"]), ")\\\\\n",
                    "\\midrule\n",
                    "Fitted $\\mu_\\theta$ &", 
                    sprintf("%.3f", mu_m), "&", 
                    sprintf("%.3f", mu_f), "\\\\\n",
                    "\\bottomrule\n",
                    "\\end{tabular}\\\\\n", sep="")
write(extrap_exp, file = file.path("..", "output", "tables", "tab_5.txt"))

#######
#print extrapolation figures 
#######
#Build one scatter-plus-fitted-curves panel per gender and store it in
#graph_hold (element 1 = female, m=0; element 2 = male, m=1).
graph_hold=list()
title=c("Female","Male")
#stack the coefficient tables: rows 1-2 are female, rows 3-4 are male, so
#row m*2+1 is the first coefficient and row m*2+2 the second for gender m
lin_coef=rbind(lin_coef_f, lin_coef_m)
exp_coef=rbind(exp_coef_f, exp_coef_m)

for(m in c(0,1)){
graph_hold[[m+1]]=extrap_dat%>%
  filter(male==m)%>%
  ggplot() +
  #one point per pcp-by-gender cell
  #NOTE(review): size is both mapped to wt and set to the constant 1.5;
  #the constant presumably takes precedence -- confirm
  geom_point(aes(x = fe_est, y = xi_est, size=wt), 
             shape = ifelse(m == 0, 2, 1), 
             color = ifelse(m == 0, "orange2", "cyan4"), 
             size=1.5, 
             stroke=1)+ 
  #scale_size_continuous(range = c(0, 10))+
  #dashed line: linear fit evaluated on a 0-1 grid
  geom_line(data = data.frame(x = seq(0, 1, by = 0.01), y = lin_coef[m*2+1, 1] + lin_coef[m*2+2, 1] * seq(0, 1, by = 0.01)),aes(x = x, y = y), color = ifelse(m == 0, "orange2", "cyan4"), linetype = "dashed", linewidth=.75) + 
  #solid line: exponential fit evaluated on the same grid
  geom_line(data = data.frame(x = seq(0, 1, by = 0.01), y = exp_coef[m*2+1, 1] * exp(exp_coef[m*2+2, 1] * seq(0, 1, by = 0.01))),  aes(x = x, y = y), color = ifelse(m == 0, "orange2", "cyan4"), linewidth=.75) +  
  #gender label inside the panel
  annotate("text",
           x = 0.9, y = 0.75, 
           label = paste0(title[m+1]),
           color = ifelse(m == 0, "orange2", "cyan4"),
           hjust = 1,
           size = 5
  )  +
  theme_classic() +
  theme(
    legend.position = "none",
    axis.title = element_blank(),
    axis.text=element_text(size=12)
  ) +
  coord_cartesian(xlim = c(0, 1), ylim = c(0, .8))
}


#Combine the two panels side by side (reversed, so male comes first) and add
#shared axis labels
combined_fig <- plot_grid(plotlist = rev(graph_hold), ncol = 2) +
  draw_label("Regression-Adjusted IPCP Referral Rate",
             x = 0.5, y = 0, vjust = 1.5, size = 16) +
  draw_label("Observed ADHD Match Signal",
             x = 0, y = 0.5, vjust = -.5, angle = 90, size = 16) +
  theme(plot.margin = margin(0.25, 1, 1, 1, "cm"))
#combined_fig
ggsave(
  filename = file.path("..", "output", "figures", "fig_4.png"),
  plot = combined_fig,
  height = 5,
  width = 10,
  units = "in"
)

#clean up
rm(
  coef_f, coef_m, combined_fig, graph_hold, title,
  exp_coef, exp_coef_f, exp_coef_m,
  lin_coef, lin_coef_f, lin_coef_m,
  extrap_exp, extrap_linear,
  xi_est_dat
)

###################
#2d- run selection tests for relevance and independence 
#(and histogram of LOO ref rate)
###################

#Independence: predicted Qi (without pcp fe) on PCP LOO ref rate 
#Relevance: Qi on PCP LOO ref rate 

#Start from the first-stage data and add a demeaned male indicator
selection_test_dat <- mutate(dat_fs, male_dm = male - mean(male))

#Probit of Qi on the first-stage covariates plus male and male_dm, without
#intercept and without the pcp-by-gender indicators
predict_nofe_f <- as.formula(
  paste0(
    "Qi ~ ",
    paste(
      names(dat_fs)[!names(dat_fs) %in% c("pat_id", "Qi", "first_pcp_id", "male", "pcpg")],
      collapse = " + "
    ),
    " + male + male_dm - 1"
  )
)

predict_nofe_reg <- glm(
  predict_nofe_f,
  family = binomial(link = "probit"),
  data = selection_test_dat
)

#predicted probability of Qi from the model without pcp fixed effects
selection_test_dat$predictedval_nofe <- predict(
  predict_nofe_reg,
  newdata = selection_test_dat,
  type = "response"
)

#Leave-one-out residual per patient:
# resid     = Qi - fe_est
# Loo_resid = -1 * (group sum of resid, excluding own) / (group size - 1),
#             where groups are pcp-by-gender cells
#             (multiplied by -1 to interpret as referral rates)
#then split into male and female versions and drop the helper columns
selection_test_dat <- selection_test_dat %>%
  left_join(pcp_fe, by = c("first_pcp_id", "male", "pcpg")) %>%
  mutate(resid = Qi - fe_est) %>%
  group_by(first_pcp_id, male) %>%
  mutate(
    sum_pat = length(unique(pat_id)),
    sum_resid = sum(resid)
  ) %>%
  ungroup() %>%
  mutate(
    Loo_resid = -1 * (sum_resid - resid) / (sum_pat - 1),
    Loo_resid_male = Loo_resid * male,
    Loo_resid_female = Loo_resid * (1 - male)
  ) %>%
  select(-c(first_pcp_id, pcpg, fe_se, resid, sum_pat, sum_resid))
  
#Density histogram of the leave-one-out rate, one facet per sex
#(distinct Loo_resid/sex combinations only)
Loo_figure <- selection_test_dat %>%
  mutate(
    sex = factor(ifelse(male == 1, "Male", "Female"),
                 levels = c("Male", "Female"))
  ) %>%
  select(Loo_resid, sex) %>%
  distinct() %>%
  ggplot() +
  geom_histogram(
    aes(x = Loo_resid, after_stat(density)),
    bins = 10,
    color = "#000000",
    fill = "khaki3"
  ) +
  theme_classic() +
  facet_wrap(~sex) +
  theme(
    legend.position = "none",
    axis.title = element_text(size = 12),
    axis.text = element_text(size = 12),
    strip.text = element_text(size = 12)
  ) +
  coord_cartesian(xlim = c(-1, 1)) +
  xlab("Leave-One-Out Residualized PCP Referral Rates")
#Loo_figure
ggsave(
  filename = file.path("..", "output", "figures", "fig_a1.png"),
  plot = Loo_figure,
  height = 5.8,
  width = 8
)



#Relevance test: Qi on the PCP leave-one-out referral rates (plus the
#first-stage covariates and male_dm), for the full sample and by gender
reg_rel_test_f <- as.formula(
  paste0(
    "Qi ~ ",
    paste(
      names(dat_fs)[!names(dat_fs) %in% c("pat_id", "Qi", "first_pcp_id", "male", "pcpg")],
      collapse = " + "
    ),
    " + Loo_resid_female + Loo_resid_male + male_dm"
  )
)

rel_full <- lm(reg_rel_test_f, data = selection_test_dat)
rel_male <- lm(reg_rel_test_f, data = selection_test_dat, subset = (male == 1))
rel_female <- lm(reg_rel_test_f, data = selection_test_dat, subset = (male == 0))

#keep estimate and HC1 robust standard error of the leave-one-out terms only
rel_full <- coeftest(rel_full, vcov = vcovHC(rel_full, type = "HC1"))[
  c("Loo_resid_male", "Loo_resid_female"),
  c("Estimate", "Std. Error")
]
rel_male <- coeftest(rel_male, vcov = vcovHC(rel_male, type = "HC1"))[
  c("Loo_resid_male"),
  c("Estimate", "Std. Error")
]
rel_female <- coeftest(rel_female, vcov = vcovHC(rel_female, type = "HC1"))[
  c("Loo_resid_female"),
  c("Estimate", "Std. Error")
]


#Independence test: same right-hand side, but the outcome is the predicted Qi
#from the model without pcp fixed effects
reg_ind_test_f <- as.formula(
  paste0(
    "predictedval_nofe ~ ",
    paste(
      names(dat_fs)[!names(dat_fs) %in% c("pat_id", "Qi", "first_pcp_id", "male", "pcpg")],
      collapse = " + "
    ),
    " + Loo_resid_male + Loo_resid_female + male_dm"
  )
)

ind_full <- lm(reg_ind_test_f, data = selection_test_dat)
ind_male <- lm(reg_ind_test_f, data = selection_test_dat, subset = (male == 1))
ind_female <- lm(reg_ind_test_f, data = selection_test_dat, subset = (male == 0))

#keep estimate and HC1 robust standard error of the leave-one-out terms only
ind_full <- coeftest(ind_full, vcov = vcovHC(ind_full, type = "HC1"))[
  c("Loo_resid_male", "Loo_resid_female"),
  c("Estimate", "Std. Error")
]
ind_male <- coeftest(ind_male, vcov = vcovHC(ind_male, type = "HC1"))[
  c("Loo_resid_male"),
  c("Estimate", "Std. Error")
]
ind_female <- coeftest(ind_female, vcov = vcovHC(ind_female, type = "HC1"))[
  c("Loo_resid_female"),
  c("Estimate", "Std. Error")
]

#######
#print relevance and independence tables 
#######

#need observations 
#n_obs = (total rows, number of males, number of females)
n_obs=c(nrow(selection_test_dat), sum(selection_test_dat$male))
n_obs=c(n_obs, n_obs[1]-n_obs[2])

#Assemble the LaTeX table (tab_a2.txt).
#rel_full / ind_full are 2x2: row 1 = Loo_resid_male, row 2 = Loo_resid_female;
#column 1 = estimate, column 2 = robust standard error.
#rel_male, rel_female, ind_male, ind_female are length-2 vectors (estimate, standard error).
selection_test_table=paste(
  "\\begin{tabular}{lccc}",
  "\\toprule",
  "& Total & Male & Female \\\\",
  "& (1) & (2) & (3) \\\\",
  "\\midrule",
  "\\multicolumn{4}{l}{\\textbf{Panel A: \\textit{Actual} Behavioral Assessment Indicator}} \\\\",
  "\\addlinespace",
  sprintf("Male PCP Referral Rate & %.3f & %.3f & - \\\\", rel_full[1,1], rel_male[1]),
  sprintf("& (%.3f) & (%.3f) & \\\\", rel_full[1,2], rel_male[2]),
  sprintf("Female PCP Referral Rate & %.3f & - & %.3f \\\\", rel_full[2,1], rel_female[1]),
  sprintf("& (%.3f) & & (%.3f) \\\\", rel_full[2,2], rel_female[2]),
  "\\hline",
  "\\addlinespace",
  "\\multicolumn{4}{l}{\\textbf{Panel B: \\textit{Predicted} Behavioral Assessment Indicator}} \\\\",
  "\\addlinespace",
  sprintf("Male PCP Referral Rate & %.3f & %.3f & -  \\\\", ind_full[1,1], ind_male[1]),
  sprintf("& (%.3f) & (%.3f) & \\\\", ind_full[1,2], ind_male[2]),
  sprintf("Female PCP Referral Rate & %.3f & - & %.3f \\\\", ind_full[2,1], ind_female[1]),
  sprintf("& (%.3f) & & (%.3f) \\\\", ind_full[2,2], ind_female[2]),
  "\\midrule",
  "Patient Demographics & Y & Y & Y \\\\",
  "Healthcare Utilization & Y & Y & Y \\\\",
  sprintf("Observations & %d & %d & %d \\\\", n_obs[1], n_obs[2], n_obs[3]),
  "\\bottomrule",
  "\\end{tabular}",
  sep = "\n"
)
write(selection_test_table, file = file.path("..", "output", "tables", "tab_a2.txt"))

#clean up 
#objects created for the selection tests
rm(ind_full, ind_male, ind_female,
   rel_full, rel_male, rel_female,
   selection_test_dat, 
   step1_reg,  n_obs, 
   predict_nofe_f, predict_nofe_reg, 
   reg_ind_test_f, reg_rel_test_f, selection_test_table,
   Loo_figure)

#first-stage datasets and helpers; mu_m, mu_f and step2_dat are kept for Step 3
rm(step1_dat, dat_fs, extrap_dat, pcp_fe,
    m,  reg_fs)

#################################################
#Step 3: Obtain the remaining parameter estimates given mu (and save)
#################################################

#For each gender: fit the probit of Di on xi among patients with Qi==1, then
#recover (sig, c_bar, rho, tau) from mu and the observed moments with
#recover_params(). The result, main_est, has one row per gender with columns
#male, mu, sig, c_bar, rho, tau, and is saved to data/intermediate/main_est.csv.
mu_fm=c(mu_f, mu_m) #position m+1: female (m=0) first, male (m=1) second

#one slot per gender, filled inside the loop; this replaces seeding main_est
#with an NA row and growing it with rbind()
est_rows=vector("list", length(mu_fm))

for (m in c(0,1)){
  mu=mu_fm[m+1] # get mu for that m 
  
  #dat_g: all patients of gender m; dat_condit: those with Qi==1
  dat_g=step2_dat%>%
    filter(male==m)%>%
    select(Qi,xi, Di)
  dat_condit=dat_g%>%
    filter(Qi==1)
  
  #probit slope and constant (data= keeps the coefficient names independent of
  #the data frame's name)
  probit_output=glm(Di ~ xi, family=binomial(link="probit"), data=dat_condit)
  slope=unname(coef(probit_output)[2])
  cons=unname(coef(probit_output)[1])
  
  #moments needed to back out the parameters, and the upper search limit for sigma
  xbar=mean(dat_condit$xi)
  qhat=mean(dat_g$Qi)
  sd_max=sd(dat_condit$xi)*100
  
  #recover the parameters
  param_out=recover_params(mu=mu, qhat=qhat, xbar=xbar, 
                           bhat=slope, ahat=cons, 
                           sigma_upper = sd_max)
  
  #store this gender's estimates as a named vector
  parm_vec=c(m, mu, param_out$sig, param_out$cbar, param_out$rho, param_out$tau)
  names(parm_vec)=c("male","mu","sig", "c_bar","rho","tau")
  est_rows[[m+1]]=parm_vec
  
  #clean up    
  rm(mu, parm_vec, probit_output)
  rm(param_out, xbar, qhat, slope, cons)
  rm(dat_g, dat_condit)
}  #end loop over m in 0,1
  
#stack the per-gender vectors into a data frame (column names come from the
#vector names)
main_est=as.data.frame(do.call(rbind, est_rows))
rm(est_rows)
#kept from the previous implementation: rows containing an NA estimate are dropped
main_est=main_est%>%
  na.omit()

#save to csv file 
write.csv(main_est,
          file = file.path("..", "data", "intermediate", "main_est.csv"),
          row.names = FALSE)


#################################################
#Step 4: Mechanism Decomposition(s)
#################################################

###################
#4a- get mechanism contribution table for method 1
#adding variation one at a time, holding fixed at either m or f
###################

#one-row parameter sets (mu, sig, c_bar, rho, tau) for each gender
param_m <- main_est %>%
  filter(male == 1) %>%
  select(mu, sig, c_bar, rho, tau)
param_f <- main_est %>%
  filter(male == 0) %>%
  select(mu, sig, c_bar, rho, tau)


#baseline simulations at each gender's own estimates; the baseline gap is the
#ratio of the male to the female simulated diagnosis rate
dx_baseline_m <- simulateDx(param_vec = param_m, N = n_opt, seed = seed_opt)
dx_baseline_f <- simulateDx(param_vec = param_f, N = n_opt, seed = seed_opt)
dx_baseline_gap <- dx_baseline_m$dx_rate / dx_baseline_f$dx_rate


############################################
#For each iteration, get the dx gap, the additional gap reduction, and the percent gap reduction at both male estimates fixed and female estimates fixed. 
#iter1- allowing mu and sig at true values 
#iter2- allowing mu, sig, and cbar at true values 
#iter3- allowing mu, sig, cbar, and rho at true 
#iter 4 is baseline (all at true values)

#build decomp table to be filled 
#first column is the iteration number (male then female)
#second column is the indicator for at_male 
#third is the diagnosis rate gap 
#fourth is the added effect 
#fifth  is the % contribution 
#empty decomposition table: iterations 1-4, each at male (1) then female (0)
decomp_table <- data.frame(
  iter = rep(c(1, 2, 3, 4), each = 2),
  at_male = rep(c(1, 0), 4),
  diag_dif = 0,
  added_eff = 0,
  contribute_eff = 0
)

#fill diag_dif row by row
#for iter k < 4 the first k+1 parameters take each gender's own value and the
#remaining ones are held at the male values (at_male==1) or the female values
#(at_male==0); iter 4 uses each gender's own values throughout
for (i in seq_len(nrow(decomp_table))) {
  #default: both genders at their own estimates (all that iter 4 needs)
  param_sim_m <- param_m
  param_sim_f <- param_f

  if (decomp_table$iter[i] != 4) {
    if (decomp_table$at_male[i] == 1) {
      #female set: own values up front, male values for the rest
      param_sim_f <- c(param_f[1:(1 + decomp_table$iter[i])],
                       param_m[(2 + decomp_table$iter[i]):5])
    } else {
      #male set: own values up front, female values for the rest
      param_sim_m <- c(param_m[1:(1 + decomp_table$iter[i])],
                       param_f[(2 + decomp_table$iter[i]):5])
    }
  }

  #ratio of simulated male to female diagnosis rates
  decomp_table$diag_dif[i] <-
    simulateDx(param_vec = param_sim_m, N = n_opt, seed = seed_opt)$dx_rate /
    simulateDx(param_vec = param_sim_f, N = n_opt, seed = seed_opt)$dx_rate
}

#added_eff: change in the gap relative to the previous iteration within the
#same at_male group (iter 1 is compared with a gap of 1)
#contribute_eff: added_eff as a percentage of the baseline gap minus 1
#rounding happens last so the percentage is computed from unrounded values
decomp_table <- decomp_table %>%
  group_by(at_male) %>%
  mutate(added_eff = ifelse(iter == 1, diag_dif - 1, diag_dif - lag(diag_dif))) %>%
  ungroup() %>%
  mutate(contribute_eff = added_eff / (dx_baseline_gap - 1) * 100) %>%
  mutate(
    diag_dif = round(diag_dif, 2),
    added_eff = round(added_eff, 2),
    contribute_eff = paste0(round(contribute_eff), "\\%", "")
  )

#Finally, print the table output 

#Assemble the LaTeX table for decomposition method 1 (tab_7.txt).
#decomp_table rows come in pairs per iteration: odd rows are "at Male
#estimates", even rows are "at Female estimates" (rows 1-2 = iter 1, ...,
#rows 7-8 = iter 4).
decomp_table_print=paste(
  "\\begin{tabular}{lccc}",
  "  \\toprule",
  "  & Diagnostic & Added & Relative \\\\",
  "  & \\underline{Difference} & \\underline{Effect} & \\underline{Contribution}\\\\",
  "  \\textbf{No Difference} & \\textbf{1.00} & - & -\\\\",
  "  \\addlinespace",
  "  \\multicolumn{2}{l}{\\textbf{Prevalence Contribution}}\\\\",
  "  \\multicolumn{2}{l}{\\textit{ADHD Risk Distribution: $\\mu_\\theta$ and $\\sigma_\\theta$}}\\\\",
  sprintf("  \\hspace{3mm} at Male estimates & %.2f & %+.2f & %s \\\\",
          decomp_table$diag_dif[1], decomp_table$added_eff[1], decomp_table$contribute_eff[1]),
  sprintf("  \\hspace{3mm} at Female estimates & %.2f & %+.2f & %s \\\\",
          decomp_table$diag_dif[2], decomp_table$added_eff[2], decomp_table$contribute_eff[2]),
  "  \\addlinespace",
  "  \\multicolumn{2}{l}{\\textbf{Patient Contribution}}\\\\",
  "  \\multicolumn{2}{l}{\\textit{Utilization Costs: $c_\\theta$}}\\\\",
  sprintf("  \\hspace{3mm} at Male estimates & %.2f & %+.2f & %s \\\\",
          decomp_table$diag_dif[3], decomp_table$added_eff[3], decomp_table$contribute_eff[3]),
  sprintf("  \\hspace{3mm} at Female estimates & %.2f & %+.2f & %s \\\\",
          decomp_table$diag_dif[4], decomp_table$added_eff[4], decomp_table$contribute_eff[4]),
  "  \\addlinespace",
  "  \\multicolumn{2}{l}{\\textbf{Physician Contribution}}\\\\",
  "  \\multicolumn{2}{l}{\\textit{Signal Quality: $\\rho_\\theta$}}\\\\",
  sprintf("  \\hspace{3mm} at Male estimates & %.2f & %+.2f & %s \\\\",
          decomp_table$diag_dif[5], decomp_table$added_eff[5], decomp_table$contribute_eff[5]),
  sprintf("  \\hspace{3mm} at Female estimates & %.2f & %+.2f & %s \\\\",
          decomp_table$diag_dif[6], decomp_table$added_eff[6], decomp_table$contribute_eff[6]),
  "  \\addlinespace",
  "  \\multicolumn{2}{l}{\\textit{Diagnostic Thresholds: $\\tau_\\theta$}}\\\\",
  sprintf("  \\hspace{3mm} at Male estimates & %.2f & %+.2f & %s \\\\",
          decomp_table$diag_dif[7], decomp_table$added_eff[7], decomp_table$contribute_eff[7]),
  sprintf("  \\hspace{3mm} at Female estimates & %.2f & %+.2f & %s\\\\",
          decomp_table$diag_dif[8], decomp_table$added_eff[8], decomp_table$contribute_eff[8]),
  "  \\addlinespace",
  "  \\hline",
  #the "100\%" cell is pasted on outside sprintf() so its % is not read as a format specifier
  paste(sprintf("  \\textbf{Overall Difference} & \\textbf{%.2f} & %+.2f",  dx_baseline_gap, dx_baseline_gap - 1)," & 100\\%\\\\"),
  "  \\hline",
  "\\end{tabular}",
  sep = "\n"
)
write(decomp_table_print, file = file.path("..", "output", "tables", "tab_7.txt"))

#clean up 
rm(decomp_table, decomp_table_print, param_sim_f, param_sim_m)


###################
#4b- get mechanism contribution table for method 2
#only varying one at a time, holding fixed at either m or f 
###################

########
#For each iteration, get the dx rate for male, for female, and the dx gap, at both male estimates fixed and female estimates fixed. 
#iter1- allowing all but mu and sig at true values 
#iter2- allowing all but cbar at true value 
#iter3- allowing  all but rho at true value
#iter4- allowing all but tau at true value 


#build decomp table to be filled 
#first column is the iteration number (male then female)
#second column is the indicator for at_male 
#third is the diagnosis rate for male 
#fourth is the diagnosis rate for female 
#fifth  is the diagnostic difference (male/female)
#empty table: iterations 1-4, each at male (1) then female (0)
decomp_table <- data.frame(
  iter = rep(c(1, 2, 3, 4), each = 2),
  at_male = rep(c(1, 0), 4),
  diag_male = 0,
  diag_female = 0,
  diag_gap = 0
)

#fill diag_male and diag_female row by row
#parameter positions: 1=mu, 2=sig, 3=c_bar, 4=rho, 5=tau; iteration k equalises
#position k+1 across genders, and iteration 1 equalises mu as well
for (i in seq_len(nrow(decomp_table))) {
  param_sim_m <- param_m
  param_sim_f <- param_f

  if (decomp_table$at_male[i] == 1) {
    #at male estimates: the female set receives the male value(s)
    param_sim_f[(decomp_table$iter[i] + 1)] <- param_m[(decomp_table$iter[i] + 1)]
    if (decomp_table$iter[i] == 1) {
      param_sim_f[1] <- param_m[1]
    }
  } else {
    #at female estimates: the male set receives the female value(s)
    param_sim_m[(decomp_table$iter[i] + 1)] <- param_f[(decomp_table$iter[i] + 1)]
    if (decomp_table$iter[i] == 1) {
      param_sim_m[1] <- param_f[1]
    }
  }

  decomp_table$diag_male[i] <-
    simulateDx(param_vec = param_sim_m, N = n_opt, seed = seed_opt)$dx_rate
  decomp_table$diag_female[i] <-
    simulateDx(param_vec = param_sim_f, N = n_opt, seed = seed_opt)$dx_rate
}


#diagnostic gap = male rate / female rate
decomp_table$diag_gap <- decomp_table$diag_male / decomp_table$diag_female


#Finally, print the table output 

#Assemble the LaTeX table for decomposition method 2 (tab_a7.txt).
#First data row is the baseline (each gender at its own estimates); after that
#decomp_table rows come in pairs per iteration: odd rows are "at Male
#estimates", even rows are "at Female estimates".
decomp_table_print2=paste(
  "\\begin{tabular}{lccc}",
  "  \\toprule",
  " & \\multicolumn{2}{c}{Diagnosis Rates} & Diagnostic \\\\",
  " & Male & Female & \\underline{Difference} \\\\ ", 
  " \\cmidrule(lr){2-3}", 
  sprintf("  \\textbf{Baseline Difference} & \\textbf{%.3f} & \\textbf{%.3f} & \\textbf{%.2f}\\\\",
          dx_baseline_m$dx_rate, dx_baseline_f$dx_rate, dx_baseline_gap),
  "  \\addlinespace",
  "  \\multicolumn{2}{l}{\\textbf{Prevalence Contribution}}\\\\",
  "  \\multicolumn{2}{l}{\\textit{ADHD Risk Distribution: $\\mu_\\theta$ and $\\sigma_\\theta$}}\\\\",
  sprintf("  \\hspace{3mm} at Male estimates & %.3f & %.3f & %.2f \\\\",
          decomp_table$diag_male[1], decomp_table$diag_female[1], decomp_table$diag_gap[1]),
  sprintf("  \\hspace{3mm} at Female estimates & %.3f & %.3f & %.2f \\\\",
          decomp_table$diag_male[2], decomp_table$diag_female[2], decomp_table$diag_gap[2]),
  "  \\addlinespace",
  "  \\multicolumn{2}{l}{\\textbf{Patient Contribution}}\\\\",
  "  \\multicolumn{2}{l}{\\textit{Utilization Costs: $c_\\theta$}}\\\\",
  sprintf("  \\hspace{3mm} at Male estimates & %.3f & %.3f & %.2f \\\\",
          decomp_table$diag_male[3], decomp_table$diag_female[3], decomp_table$diag_gap[3]),
  sprintf("  \\hspace{3mm} at Female estimates & %.3f & %.3f & %.2f \\\\",
          decomp_table$diag_male[4], decomp_table$diag_female[4], decomp_table$diag_gap[4]),
  "  \\addlinespace",
  "  \\multicolumn{2}{l}{\\textbf{Physician Contribution}}\\\\",
  "  \\multicolumn{2}{l}{\\textit{Signal Quality: $\\rho_\\theta$}}\\\\",
  sprintf("  \\hspace{3mm} at Male estimates & %.3f & %.3f & %.2f \\\\",
          decomp_table$diag_male[5], decomp_table$diag_female[5], decomp_table$diag_gap[5]),
  sprintf("  \\hspace{3mm} at Female estimates& %.3f & %.3f & %.2f \\\\",
          decomp_table$diag_male[6], decomp_table$diag_female[6], decomp_table$diag_gap[6]),
  "  \\addlinespace",
  "  \\multicolumn{2}{l}{\\textit{Diagnostic Thresholds: $\\tau_\\theta$}}\\\\",
  sprintf("  \\hspace{3mm} at Male estimates & %.3f & %.3f & %.2f \\\\",
          decomp_table$diag_male[7], decomp_table$diag_female[7], decomp_table$diag_gap[7]),
  sprintf("  \\hspace{3mm} at Female estimates & %.3f & %.3f & %.2f \\\\",
          decomp_table$diag_male[8], decomp_table$diag_female[8], decomp_table$diag_gap[8]),
  "  \\hline",
  "\\end{tabular}",
  sep = "\n"
)
write(decomp_table_print2, file = file.path("..", "output", "tables", "tab_a7.txt"))

#################################################
#Step 5: Output the Moment comparison table (A6)
#################################################

#Observed moments. The same four statistics are computed on three samples
#(all rows, male == 1, male == 0), so they are defined once:
#  D  = mean of Di                 Q  = mean of Qi
#  xQ = mean of xi where Qi == 1   DQ = mean of Di where Qi == 1
a6_obs_moments = function(dat){
  summarise(
    dat,
    D  = mean(Di),
    Q  = mean(Qi),
    xQ = mean(xi[Qi == 1]),
    DQ = mean(Di[Qi == 1])
  )
}
obs_total  = a6_obs_moments(step2_dat)
obs_male   = a6_obs_moments(filter(step2_dat, male == 1))
obs_female = a6_obs_moments(filter(step2_dat, male == 0))

#Simulated moments: one simulateDx run per gender, same N and seed
moment_m=simulateDx(param_vec=param_m, N=n_opt, seed=seed_opt)
moment_f=simulateDx(param_vec=param_f, N=n_opt, seed=seed_opt)

#Pooled ("Total") simulated moments. Both runs use N = n_opt, so the
#*_sum fields are added and divided by 2*n_opt for the unconditional
#moments, or by the pooled q_sum for the moments conditional on Q.
#NOTE(review): this weights the two genders equally, whereas the observed
#total uses the sample's own gender mix -- confirm this is intended.
a6_dx_pool = moment_m$dx_sum + moment_f$dx_sum
a6_q_pool  = moment_m$q_sum  + moment_f$q_sum
a6_n_pool  = 2 * n_opt

sim_total = c(
  D  = a6_dx_pool / a6_n_pool,
  Q  = a6_q_pool / a6_n_pool,
  xQ = (moment_m$xsum_c + moment_f$xsum_c) / a6_q_pool,
  DQ = a6_dx_pool / a6_q_pool
)

#gender-specific simulated moments are read directly off the simulateDx output
a6_sim_moments = function(mom){
  c(D = mom$dx_rate, Q = mom$q_rate, xQ = mom$xbar_c, DQ = mom$dx_c)
}
sim_male   = a6_sim_moments(moment_m)
sim_female = a6_sim_moments(moment_f)

##make the table
#one body row: label, then observed (total, male, female),
#then simulated (total, male, female) for the moment named by `key`
a6_row = function(label, key){
  sprintf("  %s & %.3f & %.3f & %.3f & %.3f & %.3f & %.3f \\\\",
          label,
          obs_total[[key]], obs_male[[key]], obs_female[[key]],
          sim_total[key], sim_male[key], sim_female[key])
}

match_moment_tab = paste(
  "\\begin{tabular}{lcccccc}",
  "  \\toprule",
  "  & \\multicolumn{3}{c}{Observed} & \\multicolumn{3}{c}{Simulated}\\\\",
  "  \\cmidrule(lr){2-4}\\cmidrule(lr){5-7}",
  "  & Total & Male & Female & Total & Male & Female \\\\",
  "  \\midrule",
  a6_row("ADHD Dx. ($D$)", "D"),
  a6_row("Behavioral Appt. ($Q$)", "Q"),
  a6_row("ADHD match ($x|Q$)", "xQ"),
  a6_row("Cond. Dx. ($D|Q$)", "DQ"),
  "  \\bottomrule",
  "\\end{tabular}",
  sep = "\n"
)

write(match_moment_tab, file = file.path("..", "output", "tables", "tab_a6.txt"))

#clean up (table objects plus the temporary helpers defined in this step)
rm(match_moment_tab, obs_total, obs_male, obs_female, sim_total, sim_male, sim_female)
rm(a6_obs_moments, a6_sim_moments, a6_row, a6_dx_pool, a6_q_pool, a6_n_pool)

####################################################################################


#################################################
#Step 6: Output the Same Threshold table (table 8)
#################################################

#helper: observed mean of Di among rows with male == g
t8_obs_rate = function(g){
  mean(step2_dat$Di[step2_dat$male == g])
}

#helper: simulated dx_rate for a given parameter vector (fixed N and seed)
t8_sim_rate = function(pv){
  simulateDx(param_vec=pv, N=n_opt, seed=seed_opt)$dx_rate
}

#Each dx_* vector below is ordered (male, female)

#rates from the data
dx_data=c(t8_obs_rate(1), t8_obs_rate(0))

#simulated rates at each gender's own estimates (baseline)
dx_base=c(t8_sim_rate(param_m), t8_sim_rate(param_f))

#counterfactual tau=tau_m: female keeps elements 1-4, takes male element 5
dx_tm=c(t8_sim_rate(param_m),
        t8_sim_rate(c(param_f[1:4],param_m[5])))

#counterfactual tau=tau_f: male keeps elements 1-4, takes female element 5
dx_tf=c(t8_sim_rate(c(param_m[1:4],param_f[5])),
        t8_sim_rate(param_f))


# Make the table
#one body row: label, male rate (%), female rate (%), male/female ratio
t8_row = function(label, rates){
  sprintf("  %s & %.1f\\%% & %.1f\\%% & %.2f \\\\",
          label, rates[1]*100, rates[2]*100, rates[1]/rates[2])
}

tab_tau = paste(
  "\\begin{tabular}{lccc}",
  "  \\toprule",
  "  & Male & Female & Diagnostic\\\\",
  "  & Diagnosis Rate & Diagnosis Rate & Gender Gap \\\\",
  "  \\midrule",
  t8_row("Data", dx_data),
  t8_row("Baseline", dx_base),
  "  \\midrule",
  t8_row(sprintf("$\\tau_f=\\widehat{\\tau_m}=%.3f$", param_m[5]), dx_tm),
  t8_row(sprintf("$\\tau_m=\\widehat{\\tau_f}=%.3f$", param_f[5]), dx_tf),
  "  \\bottomrule",
  "\\end{tabular}",
  sep = "\n"
)


####
#save this to output file
write(tab_tau, file = file.path("..", "output", "tables", "tab_8.txt"))

#drop the temporary helpers defined in this step
rm(t8_obs_rate, t8_sim_rate, t8_row)

#END OF SCRIPT

